from Global_variable_setting import optimizer_initial_lr, loss_values, metrics_values
from Global_variable_setting import weight_decay, image_size, patch_size, num_patches
from Global_variable_setting import projection_dim, num_heads, transformer_hidden_units
from Global_variable_setting import transformer_layers, mlp_head_units
from Global_variable_setting import dropout, emb_dropout

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras import Sequential
import tensorflow.keras.layers as nn

from tensorflow import einsum
from einops import rearrange
from einops.layers.tensorflow import Rearrange


def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


class PreNorm(Layer):
    """Layer-normalize the input, then hand the result to the wrapped callable.

    Args:
        fn: Callable (typically another layer) applied to the normalized input.
    """

    def __init__(self, fn):
        super().__init__()

        self.norm = nn.LayerNormalization()
        self.fn = fn

    def call(self, x):
        """Return ``fn`` applied to the layer-normalized ``x``."""
        normalized = self.norm(x)
        return self.fn(normalized)


class MLP(Layer):
    """Two-layer feed-forward block: Dense -> LeakyReLU -> Dropout -> Dense -> Dropout.

    Args:
        dim: Number of units of the second (output) dense layer.
        hidden_dim: Number of units of the first (hidden) dense layer.
        dropout: Rate shared by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()

        stack = [
            nn.Dense(units=hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(rate=dropout),
            nn.Dense(units=dim),
            nn.Dropout(rate=dropout),
        ]
        self.net = Sequential(stack)

    def call(self, x):
        """Run ``x`` through the feed-forward stack."""
        out = self.net(x)
        return out


# def mlp(x, hidden_units, dropout=0.0):
#     for units in hidden_units:
#         x = tf.keras.layers.Dense(units, activation='relu')(x)
#         x = nn.Dropout(rate=dropout)(x)
#     return x


class Attention(Layer):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Number of units of the output projection.
        heads: Number of attention heads.
        dim_head: Width of each attention head.
        dropout: Dropout rate applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax()
        # A single bias-free projection produces queries, keys and values at once.
        self.to_qkv = nn.Dense(units=heads * dim_head * 3, use_bias=False)

        # The output projection is skipped only when there is a single head
        # whose width already equals `dim`; the Sequential is then empty.
        needs_projection = heads != 1 or dim_head != dim
        output_layers = []
        if needs_projection:
            output_layers = [
                nn.Dense(units=dim),
                nn.Dropout(rate=dropout),
            ]

        self.to_out = Sequential(output_layers)

    def call(self, x):
        """Attend over the token axis of ``x`` (indexed as ``b n (h d)`` after projection)."""
        projected = self.to_qkv(x)
        chunks = tf.split(projected, num_or_size_splits=3, axis=-1)
        # Move the head axis in front of the token axis for each of q, k, v.
        q, k, v = [
            rearrange(chunk, 'b n (h d) -> b h n d', h=self.heads)
            for chunk in chunks
        ]

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.attend(scores)

        context = einsum('b h i j, b h j d -> b h i d', weights, v)
        # Concatenate the heads back into a single feature axis.
        merged = rearrange(context, 'b h n d -> b n (h d)')
        return self.to_out(merged)


class Transformer(Layer):
    """Stack of ``depth`` pre-normalized attention / MLP pairs with residual connections.

    Args:
        dim: Feature width passed to every ``Attention`` and ``MLP`` block.
        depth: Number of (attention, MLP) pairs.
        heads: Number of attention heads per block.
        dim_head: Width of each attention head.
        transformer_hidden_units: Hidden width of each ``MLP`` block.
        dropout: Dropout rate forwarded to every block.
    """

    def __init__(self, dim, depth, heads, dim_head, transformer_hidden_units, dropout=0.0):
        super().__init__()
        self.layers = []

        for _ in range(depth):
            attention = PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
            feed_forward = PreNorm(MLP(dim, transformer_hidden_units, dropout=dropout))
            self.layers.append([attention, feed_forward])

    def call(self, x):
        """Apply every block pair in order, adding the block input back after each sub-layer."""
        for block in self.layers:
            attention, feed_forward = block
            x = attention(x) + x
            x = feed_forward(x) + x

        return x


class ViT(Model):
    """Vision Transformer: patch embedding, position embedding, Transformer stack, MLP head.

    Args:
        image_size: Int or ``(height, width)`` tuple describing the input image.
        patch_size: Int or ``(height, width)`` tuple describing one patch; it
            must divide the image size evenly.
        num_actuators: Number of units of the final dense layer.
        dim: Width of each patch embedding.
        depth: Number of (attention, MLP) pairs in the Transformer.
        transformer_hidden_units: Hidden width of the MLP in each Transformer block.
        mlp_head_units: Iterable with the widths of the head's dense layers.
        heads: Number of attention heads.
        dim_head: Width of each attention head.
        dropout: Dropout rate used inside the Transformer blocks.
        emb_dropout: Dropout rate applied to the position-encoded embeddings.
        channels: Number of image channels. Defaults to 1, the value that was
            previously hard-coded.

    Raises:
        AssertionError: If the image size is not divisible by the patch size.
    """

    def __init__(self, image_size, patch_size, num_actuators, dim, depth,
                 transformer_hidden_units, mlp_head_units, heads, dim_head=64,
                 dropout=0.0, emb_dropout=0.0, channels=1):

        super(ViT, self).__init__()

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # Validate first so nothing is derived from an invalid patch grid.
        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)

        # Cut the image into non-overlapping patches, flatten each patch and
        # project it to `dim` features. The InputLayer already fixes the input
        # shape, so the Dense layer needs no `input_shape` of its own.
        self.patch_embedding = Sequential([
            tf.keras.layers.InputLayer(input_shape=(image_height, image_width, channels)),
            Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Dense(units=dim),
        ], name='patch_embedding')

        # Register the embedding through add_weight so it is tracked as a
        # trainable weight by Keras 2 and Keras 3 alike; a bare tf.Variable
        # attribute is not tracked automatically by Keras 3.
        self.pos_embedding = self.add_weight(
            name='pos_embedding',
            shape=(1, num_patches, dim),
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
        )

        self.dropout = nn.Dropout(rate=emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head,
                                       transformer_hidden_units, dropout)

        # MLP head: normalize, flatten all tokens into one vector, then a
        # Dense + LeakyReLU pair per entry of `mlp_head_units`, and finally
        # the output layer.
        head_layers = [nn.LayerNormalization(), Rearrange('b n d -> b (n d)')]
        for units in mlp_head_units:
            head_layers.append(nn.Dense(units=units))
            head_layers.append(nn.LeakyReLU())
        head_layers.append(nn.Dense(units=num_actuators))
        self.mlp_head = Sequential(head_layers, name='mlp_head')

    def call(self, img, **kwargs):
        """Map a batch of images to ``num_actuators`` outputs per image."""
        x = self.patch_embedding(img)
        # Only the token count is needed to slice the position embedding.
        num_tokens = x.shape[1]
        x += self.pos_embedding[:, :num_tokens]
        x = self.dropout(x)

        x = self.transformer(x)

        # The head flattens the token axis itself before its dense layers.
        x = self.mlp_head(x)

        return x


def get_vision_transformer(num_actuators=285):
    """Build and compile a ``ViT`` from the values in ``Global_variable_setting``.

    Args:
        num_actuators: Number of units of the model's final dense layer.
            Defaults to 285, the value that was previously hard-coded here.

    Returns:
        A ``ViT`` compiled with an SGD optimizer at ``optimizer_initial_lr``,
        ``loss_values`` as the loss and ``metrics_values`` as the only metric.
    """
    model = ViT(
        image_size,
        patch_size,
        num_actuators=num_actuators,
        dim=projection_dim,
        depth=transformer_layers,
        heads=num_heads,
        mlp_head_units=mlp_head_units,
        transformer_hidden_units=transformer_hidden_units,
        # Each attention head is as wide as the patch embedding itself.
        dim_head=projection_dim,
        dropout=dropout,
        emb_dropout=emb_dropout,
    )
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=optimizer_initial_lr),
        loss=loss_values,
        metrics=[metrics_values],
    )
    return model

# class Patches(tf.keras.layers.Layer):
#     def __init__(self, patch_size):
#         super(Patches, self).__init__()
#         self.patch_size = patch_size
#
#     def call(self, images):
#         batch_size = tf.shape(images)[0]
#         patches = tf.image.extract_patches(
#             images=images,
#             sizes=[1, self.patch_size, self.patch_size, 1],
#             strides=[1, self.patch_size, self.patch_size, 1],
#             rates=[1, 1, 1, 1],
#             padding='VALID'
#         )
#         patch_dims = patches.shape[-1]
#         patches = tf.reshape(patches, [batch_size, -1, patch_dims])
#         return patches
#
#
# class PatchEncoder(tf.keras.layers.Layer):
#     def __init__(self, num_patches, projection_dim):
#         super(PatchEncoder, self).__init__()
#         self.num_patches = num_patches
#         self.projection = tf.keras.layers.Dense(projection_dim)
#         self.position_embedding = tf.keras.layers.Embedding(
#             input_dim=num_patches, output_dim=projection_dim
#         )
#
#     def call(self, patch):
#         positions = tf.range(start=0, limit=self.num_patches, delta=1)
#         encoded = self.projection(patch)+self.position_embedding(positions)
#         return encoded
#
#
# def create_vision_transformer():
#     inputs = tf.keras.layers.Input(shape=(image_size, image_size, 1))
#     patches = Patches(patch_size)(inputs)
#     encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
#
#     for _ in range(transformer_layers):
#         x1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
#         attention_output = tf.keras.layers.MultiHeadAttention(
#             num_heads=num_heads, key_dim=projection_dim
#         )(x1, x1)
#         x2 = tf.keras.layers.Add()([attention_output, encoded_patches])
#         x3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x2)
#         x3 = mlp(x3, hidden_units=transformer_units)
#         encoded_patches = tf.keras.layers.Add()([x3, x2])
#     representation = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
#     representation = tf.keras.layers.Flatten()(representation)
#     features = mlp(representation, hidden_units=mlp_head_units)
#     result = tf.keras.layers.Dense(64)(features)
#     model = tf.keras.Model(inputs=inputs, outputs=result)
#     model.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=optimizer_initial_lr),
#         loss=loss_values,
#         metrics=[metrics_values]
#     )
#     return model
